{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "start_time": "2019-05-06T15:58:14.995Z"
    },
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "import time\n",
    "t1 = time.time()\n",
    "! rm -rf yolo\n",
    "! git clone https://gitlab.com/prerakmody/yolo.git\n",
    "# ! pip install gputil\n",
    "print (round(time.time() - t1), ' secs')\n",
    "\n",
    "import os\n",
    "import sys\n",
    "import pdb; # pdb.set_trace()\n",
    "import matplotlib.pyplot as plt\n",
    "%matplotlib inline\n",
    "\n",
    "import torch\n",
    "from torch.utils.data import DataLoader\n",
    "import torchvision.transforms as transforms\n",
    "from tensorboardcolab import TensorBoardColab\n",
    "\n",
    "USE_GPU = torch.cuda.is_available()\n",
    "print (' - USE_GPU : ', USE_GPU)\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Training"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "start_time": "2019-05-06T07:35:12.290Z"
    },
    "code_folding": [
     37
    ],
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "dir_main = './yolo'\n",
    "sys.path.append(dir_main)\n",
    "\n",
    "from src.dataloader import YoloDataset\n",
    "from src.nets import *\n",
    "from src.train import YOLOv1Train, YOLOv1Loss\n",
    "# from src.train import YOLOv1Loss\n",
    "torch.cuda.empty_cache()\n",
    "\n",
    "if __name__ == \"__main__\":\n",
    "    \n",
    "    if (1):\n",
    "        model_name = 'yolov1'\n",
    "        model = getYOLOv1(model_name)\n",
    "        print (' - 1. Net : ', model_name)\n",
    "    \n",
    "    if (1):\n",
    "        DEBUG         = False\n",
    "        LEARNING_RATE = 0.001\n",
    "        EPOCHS        = 120\n",
    "        BATCH_SIZE    = 32\n",
    "        criterion     = YOLOv1Loss('','',5,0.5)\n",
    "        optimizer     = 'SGD'\n",
    "        \n",
    "        print (' - 2. [PARAMS] Debug      : ', DEBUG)\n",
    "        print (' - 2. [PARAMS] BATCH_SIZE : ', BATCH_SIZE)\n",
    "        print (' - 2. [PARAMS] EPOCHS     : ', EPOCHS)\n",
    "        print (' - 2. [PARAMS] Optimizer  : ', optimizer)\n",
    "        print ('')\n",
    "        \n",
    "    if (1):\n",
    "        CHKP_DIR    = 'chkpoints3'\n",
    "        CHKP_NAME   = 'yolov1_epoch%.3d.pkl' % (69)\n",
    "        CHKP_EPOCHS = 1\n",
    "        CHKP_LOAD   = True\n",
    "        print (' - 2. [PARAMS_SAVE] CHKP_DIR  : ', CHKP_DIR)\n",
    "        print (' - 2. [PARAMS_SAVE] CHKP_NAME : ', CHKP_NAME)\n",
    "        \n",
    "    if (1):\n",
    "        dir_annotations  = os.path.join(dir_main, 'data/VOCdevkit_trainval/VOC2007')\n",
    "        file_annotations = os.path.join(dir_main,'data/VOCdevkit_trainval/VOC2007/anno_trainval.txt')\n",
    "        IMAGE_SIZE       = 448\n",
    "        IMAGE_GRID         = 7\n",
    "        flag_augm        = True\n",
    "        trainFlag        = True\n",
    "\n",
    "        YoloDatasetTrain = YoloDataset(dir_annotations, file_annotations\n",
    "                                    , trainFlag\n",
    "                                    , IMAGE_SIZE, IMAGE_GRID\n",
    "                                    , flag_augm\n",
    "                                    , transform = [transforms.ToTensor()] )\n",
    "        DataLoaderTrain = DataLoader(YoloDatasetTrain, batch_size=BATCH_SIZE, shuffle=False,num_workers=0)\n",
    "        \n",
    "            dir_annotations  = os.path.join(dir_main, 'data/VOCdevkit_test/VOC2007')\n",
    "            file_annotations = os.path.join(dir_main, 'data/VOCdevkit_test/VOC2007/anno_test.txt')\n",
    "            trainFlag        = False\n",
    "            YoloDatasetTest  = YoloDataset(dir_annotations, file_annotations\n",
    "                                        , trainFlag\n",
    "                                        , IMAGE_SIZE, IMAGE_GRID\n",
    "                                        , flag_augm\n",
    "                                        , transform = [transforms.ToTensor()] )\n",
    "            DataLoaderTest   = DataLoader(YoloDatasetTest, batch_size=BATCH_SIZE, shuffle=False,num_workers=0)\n",
    "        \n",
    "        print(' - 3. [TrainDataset] %d images' % (len(YoloDatasetTrain)))\n",
    "    \n",
    "    if (1):\n",
    "        print (' - 4. Logger ')\n",
    "        pass\n",
    "        LOGGER = TensorBoardColab()\n",
    "        print (' - 4. Logger : ', LOGGER)\n",
    "        \n",
    "    \n",
    "    if (1):\n",
    "        print (' - 5. Training')\n",
    "        print (' -------------------------------------------------------------------- ')\n",
    "        trainObj = YOLOv1Train()\n",
    "        trainObj.train(model, criterion, optimizer\n",
    "                            , DataLoaderTrain, DataLoaderTest\n",
    "                            , LEARNING_RATE, EPOCHS, BATCH_SIZE\n",
    "                            , USE_GPU, LOGGER\n",
    "                            , CHKP_LOAD, CHKP_DIR, CHKP_NAME, CHKP_EPOCHS\n",
    "                            , DEBUG)"
   ]
  }
 ],
 "metadata": {
  "anaconda-cloud": {},
  "kernelspec": {
   "display_name": "Python [default]",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.5.5"
  },
  "varInspector": {
   "cols": {
    "lenName": 16,
    "lenType": 16,
    "lenVar": 40
   },
   "kernels_config": {
    "python": {
     "delete_cmd_postfix": "",
     "delete_cmd_prefix": "del ",
     "library": "var_list.py",
     "varRefreshCmd": "print(var_dic_list())"
    },
    "r": {
     "delete_cmd_postfix": ") ",
     "delete_cmd_prefix": "rm(",
     "library": "var_list.r",
     "varRefreshCmd": "cat(var_dic_list()) "
    }
   },
   "types_to_exclude": [
    "module",
    "function",
    "builtin_function_or_method",
    "instance",
    "_Feature"
   ],
   "window_display": false
  }
 },
 "nbformat": 4,
 "nbformat_minor": 1
}
